#! /usr/bin/env python3

import argparse
import logging
import time
from datetime import datetime
import jsonrpcclient
import psutil
import numpy
from decimal import Decimal


# Silence jsonrpcclient's verbose request/response logging by raising both
# loggers to WARNING.
for _noisy_logger in (
    "jsonrpcclient.client.request",
    "jsonrpcclient.client.response",
):
    logging.getLogger(_noisy_logger).setLevel(logging.WARNING)
del _noisy_logger  # do not leave the loop variable in the module namespace


def now():
    """Return the current local time as a ``YYYY-MM-DD HH:MM:SS`` string."""
    timestamp_layout = "%Y-%m-%d %H:%M:%S"
    return datetime.now().strftime(timestamp_layout)


def fstr(v: float):
    """Render ``v`` in fixed-point notation with two decimal places."""
    return format(v, ".2f")


def basic(client, ip):
    """Build the multi-line summary text printed once at startup.

    The CPU count and memory size are read with psutil, so they describe the
    machine running this script. Shard count, network id and peers come from
    the ``getStats`` JSON-RPC reply.

    Args:
        client: JSON-RPC client whose ``send`` method performs the request.
        ip: Cluster address, echoed into the summary as given.

    Returns:
        The summary as a single string; every line ends with a newline.
    """
    reply = client.send(jsonrpcclient.Request("getStats"))
    cpu_total = psutil.cpu_count()
    # Bytes to whole gigabytes, truncated.
    memory_gb = int(psutil.virtual_memory().total / 1024 / 1024 / 1024)
    peer_list = ", ".join(reply["peers"])

    parts = [
        "QuarkChain Cluster Stats\n\n",
        "CPU:                {}\n".format(cpu_total),
        "Memory:             {} GB\n".format(memory_gb),
        "IP:                 {}\n".format(ip),
        "Shards:             {}\n".format(reply["shardSize"]),
        "Network Id:         {}\n".format(reply["networkId"]),
        "Peers:              {}\n".format(peer_list),
    ]
    return "".join(parts)


def stats(client):
    """Fetch ``getStats`` once and return one display row of strings.

    Every value is pre-formatted as a string so the row can be passed
    straight to ``str.format(**row)``.

    Args:
        client: JSON-RPC client whose ``send`` method performs the request.

    Returns:
        A dict with keys ``time``, ``syncing``, ``tps``, ``pendingTx``,
        ``confirmedTx``, ``bps``, ``sbps``, ``cpu``, ``root`` and ``shards``.
    """
    reply = client.send(jsonrpcclient.Request("getStats"))

    def per_second(counter_key):
        # The "...60s" counters are divided by 60 to show a per-second rate
        # (presumably they cover the last 60 seconds; confirm with the server).
        return fstr(reply[counter_key] / 60)

    row = {"time": now()}
    row["syncing"] = str(reply["syncing"])
    row["tps"] = per_second("txCount60s")
    row["pendingTx"] = str(reply["pendingTxCount"])
    row["confirmedTx"] = str(reply["totalTxCount"])
    row["bps"] = per_second("blockCount60s")
    row["sbps"] = per_second("staleBlockCount60s")
    row["cpu"] = fstr(numpy.mean([reply["cpus"]]))
    row["root"] = str(reply["rootHeight"])
    # A shard entry without a "height" key is shown as -1.
    shard_heights = [entry.get("height", -1) for entry in reply["shards"]]
    row["shards"] = str(shard_heights)
    return row


def query_stats(client, args):
    """Print a header row, then one stats row per interval, forever.

    Args:
        client: JSON-RPC client handed to ``stats``.
        args: Parsed options; reads ``args.verbose`` (selects the wide
            layout) and ``args.interval`` (seconds slept between rows).

    This function never returns on its own.
    """
    verbose_layout = "{time:20} {syncing:>8} {tps:>5} {pendingTx:>10} {confirmedTx:>10} {bps:>9} {sbps:>9} {cpu:>9} {root:>7} {shards}"
    compact_layout = "{time:20} {syncing:>8} {root:>7} {shards}"
    layout = verbose_layout if args.verbose else compact_layout

    # The header supplies every column title; the compact layout simply
    # ignores the keys it does not mention.
    header = {
        "time": "Timestamp",
        "syncing": "Syncing",
        "tps": "TPS",
        "pendingTx": "Pend.TX",
        "confirmedTx": "Conf.TX",
        "bps": "BPS",
        "sbps": "SBPS",
        "cpu": "CPU",
        "root": "ROOT",
        "shards": "SHARDS",
    }
    print(layout.format(**header))

    while True:
        row = stats(client)
        print(layout.format(**row))
        time.sleep(args.interval)


def format_qkc(qkc: Decimal):
    """Format an amount with up to 18 decimals and no trailing zeros.

    The value is rendered with exactly 18 fractional digits, then trailing
    zeros and a dangling decimal point are stripped. Zero is always ``"0"``.
    """
    if qkc == 0:
        return "0"
    fixed_point = format(qkc, ".18f")
    without_zeros = fixed_point.rstrip("0")
    return without_zeros.rstrip(".")


def query_address(client, args):
    """Poll and print the balances of one account, forever.

    Args:
        client: JSON-RPC client whose ``send`` method performs the request.
        args: Parsed options; reads ``args.address`` (48 hex digits, with or
            without a ``0x`` prefix, any letter case) and ``args.interval``
            (seconds slept between polls).

    Raises:
        ValueError: If the address is not exactly 48 hexadecimal digits after
            the optional ``0x`` prefix is removed.

    This function never returns on its own.
    """
    address_hex = args.address.lower()
    # Remove only a literal "0x" prefix. str.lstrip() strips a *set* of
    # characters, so lstrip("0") would also eat the leading zero digits of an
    # address that was given without the prefix.
    if address_hex.startswith("0x"):
        address_hex = address_hex[2:]
    # Raise explicitly instead of asserting: asserts vanish under `python -O`.
    if len(address_hex) != 48 or any(
        ch not in "0123456789abcdef" for ch in address_hex
    ):
        raise ValueError(
            "address must be 48 hex digits, optionally prefixed with 0x: "
            "{!r}".format(args.address)
        )

    print("Querying balances for 0x{}".format(address_hex))
    row_format = "{time:20} {total:>18}    {shards}"
    print(row_format.format(time="Timestamp", total="Total", shards="Shards"))

    # Balances are divided by 10**18 before display.
    unit = 10 ** 18
    while True:
        data = client.send(
            jsonrpcclient.Request(
                "getAccountData", address="0x" + address_hex, include_shards=True
            )
        )
        # Each shard balance arrives as a hexadecimal string.
        shards_wei = [int(s["balance"], 16) for s in data["shards"]]
        total = format_qkc(Decimal(sum(shards_wei)) / unit)
        shards_qkc = ", ".join(format_qkc(Decimal(wei) / unit) for wei in shards_wei)
        print(row_format.format(time=now(), total=total, shards=shards_qkc))
        time.sleep(args.interval)


def main():
    """Parse the command line, print the cluster summary, then start polling.

    Two JSON-RPC clients are created for the chosen host: one on port 38491
    (used for the summary and the stats loop) and one on port 38391 (used
    for account balance queries). Balances are polled when ``--address`` is
    given; otherwise cluster stats are polled.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--ip", type=str, default="localhost", help="Cluster IP")
    parser.add_argument(
        "-v", "--verbose", action="store_true", help="Show more details"
    )
    parser.add_argument(
        "-i", "--interval", type=int, default=10, help="Query interval in second"
    )
    parser.add_argument(
        "-a",
        "--address",
        type=str,
        default="",
        help="Query account balance if a QKC address is provided",
    )
    options = parser.parse_args()

    host = options.ip
    private_client = jsonrpcclient.HTTPClient("http://{}:38491".format(host))
    public_client = jsonrpcclient.HTTPClient("http://{}:38391".format(host))

    print(basic(private_client, host))

    # No address given: fall back to the cluster stats loop.
    if not options.address:
        query_stats(private_client, options)
        return
    query_address(public_client, options)


if __name__ == "__main__":
    # Run the monitor only when executed as a script. The target cluster is
    # selected with --ip, which defaults to localhost.
    main()
